{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Comparison of Planar, Radial, and Affine Coupling Flows"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we train normalizing flows to fit predefined prior distributions, testing their expressivity. The plots are generated to visualize the learned distributions for given layers $K$, and the training loss is plotted to compare the expressivity of different flows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Import required packages\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "import normflows as nf\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "print(\"PyTorch version: %s\" % torch.__version__)\n",
    "dev = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(\"Using device: %s\" % dev)\n",
    "\n",
    "#z shape is (batch_size, num_samples, dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prior target distribution visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "priors = []\n",
    "priors.append(nf.distributions.TwoModes(2.0, 0.2))\n",
    "priors.append(nf.distributions.Sinusoidal(0.4, 4))\n",
    "priors.append(nf.distributions.Sinusoidal_gap(0.4, 4))\n",
    "priors.append(nf.distributions.Sinusoidal_split(0.4, 4))\n",
    "priors.append(nf.distributions.Smiley(0.15))\n",
    "\n",
    "\n",
    "# Plot prior distributions\n",
    "grid_size = 200\n",
    "grid_length = 4.0\n",
    "grid_shape = ([-grid_length, grid_length], [-grid_length, grid_length])\n",
    "\n",
    "space_mesh = torch.linspace(-grid_length, grid_length, grid_size)\n",
    "xx, yy = torch.meshgrid(space_mesh, space_mesh)\n",
    "z = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2)\n",
    "z = z.reshape(-1, 2)\n",
    "\n",
    "K_arr = [2, 8, 32]\n",
    "max_iter = 30000\n",
    "batch_size = 512\n",
    "num_samples = 256\n",
    "save_iter = 1000\n",
    "\n",
    "for k in range(len(priors)):\n",
    "    log_prob = priors[k].log_prob(z)\n",
    "    prob = torch.exp(log_prob)\n",
    "\n",
    "    plt.figure(figsize=(10, 10))\n",
    "    plt.pcolormesh(xx, yy, prob.reshape(grid_size, grid_size))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Flow training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "flow_types = (\"Planar\", \"Radial\", \"NICE\", \"RealNVP\")\n",
    "max_iter = 20000\n",
    "batch_size = 1024\n",
    "plot_batches = 10 ** 2\n",
    "plot_samples = 10 ** 4\n",
    "save_iter = 50\n",
    "\n",
    "for name in flow_types:\n",
    "    K_arr = [2, 8, 32]\n",
    "    for K in K_arr:\n",
    "        print(\"Flow type {} with K = {}\".format(name, K))\n",
    "        for k in range(len(priors)):\n",
    "            if k == 0 or k == 4:\n",
    "                anneal_iter = 10000\n",
    "            else: # turn annealing off when fitting to sinusoidal distributions\n",
    "                anneal_iter = 1\n",
    "        \n",
    "            flows = []\n",
    "            b = torch.tensor([0,1])\n",
    "            for i in range(K):\n",
    "                if name == \"Planar\":\n",
    "                    flows += [nf.flows.Planar((2,))]\n",
    "                elif name == \"Radial\":\n",
    "                    flows += [nf.flows.Radial((2,))]\n",
    "                elif name == \"NICE\":\n",
    "                    flows += [nf.flows.MaskedAffineFlow(b, nf.nets.MLP([2, 16, 16, 2], init_zeros=True))]\n",
    "                elif name == \"RealNVP\":\n",
    "                    flows += [nf.flows.MaskedAffineFlow(b, nf.nets.MLP([2, 16, 16, 2], init_zeros=True), \n",
    "                                                        nf.nets.MLP([2, 16, 16, 2], init_zeros=True))]\n",
    "                b = 1-b # parity alternation for mask\n",
    "\n",
    "            q0 = nf.distributions.DiagGaussian(2)\n",
    "            nfm = nf.NormalizingFlow(p=priors[k], q0=q0, flows=flows)\n",
    "            nfm.to(dev) # Move model on GPU if available\n",
    "    \n",
    "            # Train model\n",
    "            loss_hist = np.array([])\n",
    "            log_q_hist = np.array([])\n",
    "            log_p_hist = np.array([])\n",
    "            x = torch.zeros(batch_size, device=dev)\n",
    "\n",
    "            optimizer = torch.optim.Adam(nfm.parameters(), lr=1e-3, weight_decay=1e-3)\n",
    "            for it in tqdm(range(max_iter)):\n",
    "                optimizer.zero_grad()\n",
    "                loss = nfm.reverse_kld(batch_size, np.min([1.0, 0.01 + it / anneal_iter]))\n",
    "                if ~(torch.isnan(loss) | torch.isinf(loss)):\n",
    "                    loss.backward()\n",
    "                    optimizer.step()\n",
    "            \n",
    "                if (it + 1) % save_iter == 0:\n",
    "                    loss_hist = np.append(loss_hist, loss.cpu().data.numpy())\n",
    "\n",
    "            # Plot learned posterior distribution\n",
    "            z_np = np.zeros((0, 2))\n",
    "            for i in range(plot_batches):\n",
    "                z, _ = nfm.sample(plot_samples)\n",
    "                z_np = np.concatenate((z_np, z.cpu().data.numpy()))\n",
    "            plt.figure(figsize=(10, 10))\n",
    "            plt.hist2d(z_np[:, 0], z_np[:, 1], (grid_size, grid_size), grid_shape)\n",
    "            plt.show()\n",
    "            np.save(\"{}-K={}-k={}\".format(name,K,k), (z_np, loss.cpu().data.numpy()))\n",
    "    \n",
    "            # Plot training history\n",
    "            plt.figure(figsize=(10, 10))\n",
    "            plt.plot(loss_hist, label='loss')\n",
    "            plt.legend()\n",
    "            plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Expressivity plot of flows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(14, 10))\n",
    "K_arr = [2, 8, 32]\n",
    "nrows=5\n",
    "ncols=7\n",
    "axes = [ fig.add_subplot(nrows, ncols, r * ncols + c + 1) for r in range(0, nrows) for c in range(0, ncols) ]\n",
    "\n",
    "for ax in axes:\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "\n",
    "grid_size = 100\n",
    "grid_length = 4.0\n",
    "grid_shape = ([-grid_length, grid_length], [-grid_length, grid_length])\n",
    "\n",
    "space_mesh = torch.linspace(-grid_length, grid_length, grid_size)\n",
    "xx, yy = torch.meshgrid(space_mesh, space_mesh)\n",
    "z = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2)\n",
    "z = z.reshape(-1, 2)\n",
    "axes[0].annotate('Target', xy=(0.5, 1.10), xytext=(0.5, 1.20), xycoords='axes fraction', \n",
    "            fontsize=24, ha='center', va='bottom',\n",
    "            arrowprops=dict(arrowstyle='-[, widthB=1.5, lengthB=0.2', lw=2.0))\n",
    "for k in range(5):\n",
    "    axes[k*ncols].set_ylabel('{}'.format(k+1), rotation=0, fontsize=20, labelpad=15)\n",
    "    log_prob = priors[k].log_prob(z)\n",
    "    prob = torch.exp(log_prob)\n",
    "    axes[k*ncols + 0].pcolormesh(xx, yy, prob.reshape(grid_size, grid_size))\n",
    "\n",
    "\n",
    "for l in range(len(K_arr)):\n",
    "    K = K_arr[l]\n",
    "    if l == 1:\n",
    "        axes[0*ncols + l+1].annotate('Planar flows', xy=(0.5, 1.10), xytext=(0.5, 1.20), xycoords='axes fraction', \n",
    "            fontsize=24, ha='center', va='bottom',\n",
    "            arrowprops=dict(arrowstyle='-[, widthB=6.0, lengthB=0.2', lw=2.0))\n",
    "    axes[4*ncols + l+1].set_xlabel('K = {}'.format(K), fontsize=20)\n",
    "    for k in range(5):\n",
    "        z_np, _ = np.load(\"Planar-K={}-k={}.npy\".format(K,k), allow_pickle=True)\n",
    "        axes[k*ncols + l+1].hist2d(z_np[:, 0], z_np[:, 1], (grid_size, grid_size), grid_shape)\n",
    "        \n",
    "for l in range(len(K_arr)):\n",
    "    K = K_arr[l]\n",
    "    if l == 1:\n",
    "        axes[0*ncols + l+1+len(K_arr)].annotate('Radial flows', xy=(0.5, 1.10), xytext=(0.5, 1.20), xycoords='axes fraction', \n",
    "            fontsize=24, ha='center', va='bottom',\n",
    "            arrowprops=dict(arrowstyle='-[, widthB=6.0, lengthB=0.2', lw=2.0))\n",
    "    axes[4*ncols + l+1+len(K_arr)].set_xlabel('K = {}'.format(K), fontsize=20)\n",
    "    for k in range(5):\n",
    "        z_np, _ = np.load(\"Radial-K={}-k={}.npy\".format(K,k), allow_pickle=True)\n",
    "        axes[k*ncols + l+1+len(K_arr)].hist2d(z_np[:, 0], z_np[:, 1], (grid_size, grid_size), grid_shape)\n",
    "\n",
    "fig.subplots_adjust(hspace=0.02, wspace=0.02)\n",
    "\n",
    "for l in range(1,4):\n",
    "    for k in range(5):\n",
    "        pos1 = axes[k*ncols + l].get_position() # get the original position \n",
    "        pos2 = [pos1.x0 + 0.01, pos1.y0,  pos1.width, pos1.height] \n",
    "        axes[k*ncols + l].set_position(pos2) # set a new position\n",
    "        \n",
    "for l in range(4,7):\n",
    "    for k in range(5):\n",
    "        pos1 = axes[k*ncols + l].get_position() # get the original position \n",
    "        pos2 = [pos1.x0 + 0.02, pos1.y0,  pos1.width, pos1.height] \n",
    "        axes[k*ncols + l].set_position(pos2) # set a new position\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparison of Planar, Radial,  and Affine on given prior distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from itertools import repeat\n",
    "\n",
    "k_arr = [0, 2, 4]\n",
    "fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))\n",
    "markers = ['s', 'o', 'v', 'P', 'd']\n",
    "\n",
    "for k in range(len(k_arr)):\n",
    "    loss = [[] for i in repeat(None, len(flow_types))]\n",
    "    for intt, name in enumerate(flow_types):\n",
    "        for K in K_arr:\n",
    "            _, loss_v = np.load(\"{}-K={}-k={}.npy\".format(name,K,k), allow_pickle=True)\n",
    "            loss[intt].append(loss_v)\n",
    "        axes[k].plot(K_arr, loss[intt], marker=markers[intt], label=name)\n",
    "    axes[k].set_title('Target {}'.format(k_arr[k]+1), fontsize=16)\n",
    "    axes[k].set_xlabel('Flow length', fontsize=12)\n",
    "    axes[k].set_ylabel('Variational bound (nats)', fontsize=12)\n",
    "    axes[k].legend()\n",
    "    axes[k].grid('major')\n",
    "\n",
    "fig.tight_layout(pad=2.0)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
